之前有講到如何儲存、與載入Layer API的模型
今天就是來載入別人處理好的模型
來做影像分類
在這裡用的是Google的研究團隊提出的CNN模型,MobileNet,在目前有的模型中,例如說VGG,RestNet等等的神經網路層深度小很多,當然也就相對沒有那麼準確,不過就要看要用在什麼場合,MobileNet顧名思義就是希望能夠在計算能力較差的行動裝置內應用,所以核心思想就是減化計算量。
那MobileNet是怎麼減化計算量的呢?
先回想一下前幾天做的MNIST手寫數字辨識,裡面最大的部份發生在卷積的計算中,如果要減少計算量,基本上就是從這邊下手
例如像奇異值分解(SVD),把影像矩陣拆成兩個較小矩陣的乘積,那麼在全連接層的計算權重就會減少很多,不過要怎麼用Layer API做到這件事,好像有點想不太到
根據MobileNet論文的說法
MobileNet在減化計算量方面是透過將卷積計算改成一種叫Depthwise Separable Convolution(不知道怎麼翻譯比較好)的方式
一般的卷積計算
kernel_size* kernel_sizeinput_chaneloutput_chanelfeature_map_size feature_map_size
他只是把它分開來算
分成depthwise convolution 負責用kenal去濾資料
跟pointwise convolution 負責轉換
所以第一步depthwise convolution
kernel_size* kernel_sizeinput_chanel feature_map_size feature_map_size
第二步pointwise convolution再從input轉output
nput_chaneloutput_chanelfeature_map_size feature_map_size
稍微寫清楚一點
另K=kernel_size* kernel_size
另M= input_chanel
另N= output_chanel
另F= feature_map_size* feature_map_size
那原先的計算量就是KMNF
Depthwise Separable Convolution的計算量是KMF+MN*F
分開來的計算量/原先計算量
=(KMF+MNF)/(KMNF)
=(1/N)+(1/K)
因為MobileNet中設定的kernel_size=3
所以K=9
(1/K)=0.1111111
output_channel肯定不會只又1(因為要形成不同的特徵圖)
所以結果來說會快很多
以下是取自論文的比較表
可以看到參數少了很多(雖然每一層會變成兩層,但因為層數還是比別人少,所以總參數也是少很多)
而乘法與加法的計算量相對少了很多
精準度也沒差多少
PS使用的最佳化的函數是RMSprop
另外還有就是控制模型大小的超參數: width multiplier
會影響input_chanel、output_chanel的大小
可設置區間(0,1]之間,可以將計算量與參數數量降低
另外是控制特徵圖大小: resolution multiplier
常見224、192、160、128
以下是基本的使用範例
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>mobilenet</title>
<!-- TensorFlow.js Core -->
<script
src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.7/dist/tf.min.js">
</script>
<!-- TensorFlow.js mobilenet model -->
<script
src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet">
</script>
</head>
<body>
<img id="img" crossOrigin src="https://i.imgur.com/WRitB1p.jpg"
width=400 height=400 />
<p id="result"></p>
<script>
async function app() {
const mobilenet_model = await mobilenet.load();
const img = document.getElementById('img');
const result = await mobilenet_model.classify(img);
document.getElementById('result').innerHTML = result[0].className;
console.table(result);
}
app();
</script>
</body>
</html>